#include "statsxx/machine_learning/NeuralNet.hpp"

// STL
#include <cmath>
#include <iostream>
#include <limits>                  // std::numeric_limits<>
#include <tuple>                   // std::tuple, std::make_tuple
#include <utility>                 // std::move
#include <vector>                  // std::vector

// jScience
#include "jScience/stl/vector.hpp" // std::vector *=

// stats++
#include "statsxx/statistics.hpp"  // statistics::mean(), ::stddev()


//
// DESC: Parses a data set through the network and evaluates the error of every output.
//
// INPUT:
//          DataSet ds      : Data set to parse (taken by value)
//
// OUTPUT (std::tuple):
//          double          : Mean of the individual error terms (non-classification: sum-of-squares terms, classification: cross-entropy terms)
//          double          : Standard deviation of the individual error terms (as computed by statistics::stddev())
//          std::vector<std::vector<double>> : Network output for each data point, in the order of ds.pt
//
// NOTES:
//     ! IMPORTANT: train() should have been called at least once prior to this, in order to store min/max/etc values
//     ! The returned mean/stddev are taken over individual error terms, not over the summed E(w):
//          - non-classification:       one term (1/2)*(y_pn - t_pn)^2 per pattern AND per output node
//          - binomial classification:  (single output node) one term per pattern; a target == 0 is treated as 0,
//                                      any other target as 1
//          - multinomial:              one term -ln(y_pn) per output node whose target is non-zero; the target
//                                      value is not used as a weight, so this matches E(w) below only for 0/1 targets
//     ! Every -ln(x) term is capped at 100 (used whenever x < exp(-100)) so that outputs at (or numerically past)
//       the 0/1 limits yield a large finite penalty instead of an infinity/NaN
//     ! Error functions:
//          E(w) = (1/2)*sum_{p=1,P}*sum_{n=1,N}[y_pn - t_pn]^2           (non-classification)
//          E(w) = -sum_{p=1,P}[t_p*ln(y_p) + (1 - t_p)*ln(1 - y_p)]      (binomial classification)
//          E(w) = -sum_{p=1,P}*sum_{n=1,N}[t_pn*ln(y_pn)]                (multinomial classification)
//          where:
//               w = weights
//               p = pattern ([1,...,P])
//               n = output node ([1,...,N])
//               t_pn = target output of pattern p and node n
//               y_pn = calculated output of pattern p and node n
//
inline std::tuple<double, double, std::vector<std::vector<double>>> NEURAL_NET::parse_TS(DataSet ds)
{
    // << JMM: is this really required? because of the ln() in classification problems, we have to prevent against
    //    infinities -- so the limit of 0*ln(0) is manually implemented and 1*ln(0) = a very large number >>
    const double very_large = 100.0;
    //double very_large = std::numeric_limits<double>::max()/TS.pt.size();
    const double very_small = std::exp(-very_large);

    // -ln(x), capped at very_large; arguments below very_small (including 0 and negatives) get the cap
    const auto capped_neg_log = [very_large, very_small](double x) -> double
    {
        return ( x < very_small ) ? very_large : -std::log(x);
    };

    std::vector<double> err;
    std::vector<std::vector<double>> output;
    output.reserve( ds.pt.size() );

    for( const auto &pt : ds.pt )
    {
        std::vector<double> tmp_out;
        this->evaluate( pt.in, tmp_out );

        for(int i = 0; i < m_nout; ++i)
        {
            if( m_isClassif )
            {
                if( m_nout > 1 ) // multinomial: only nodes with a non-zero target contribute
                {
                    if( pt.out[i] != 0 )
                    {
                        err.push_back( capped_neg_log(tmp_out[i]) );
                    }
                }
                else             // binomial: target 0 -> -ln(1 - y), otherwise -ln(y)
                {
                    const double prob = ( pt.out[i] == 0 ) ? (1.0 - tmp_out[i]) : tmp_out[i];
                    err.push_back( capped_neg_log(prob) );
                }
            }
            else
            {
                // NOTE: because the output is naturally normalized, different types of output are similarly weighted (?)
                // the 1/2 is the conventional factor that simplifies the error function's derivative
                const double diff = tmp_out[i] - pt.out[i];
                err.push_back( (diff*diff)*0.5 );
            }
        }

        // tmp_out is not used again: move it rather than copy it
        output.push_back( std::move(tmp_out) );
    }

    // mean/stddev only read err, so moving output into the tuple is safe regardless of argument evaluation order
    return std::make_tuple(
                           statistics::mean(err),
                           statistics::stddev(err),
                           std::move(output)
                           );
}
